%%writefile mixture_of_experts.py
import numpy as np
import torch
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.mps.is_available() else 'cpu'
device = torch.device(device)
print('Using device', device)
######################################################################
class Expert(torch.nn.Module):

    def __init__(self, n_inputs, hidden_units_by_layer, n_outputs):
        super().__init__()
        self.hidden_units_by_layer = hidden_units_by_layer
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.hidden_layers = torch.nn.ModuleList()
        for nh in hidden_units_by_layer:
            self.hidden_layers.append(torch.nn.Sequential(
                torch.nn.Linear(n_inputs, nh),
                torch.nn.Tanh()))
            n_inputs = nh
        self.output_layer = torch.nn.Sequential(
            torch.nn.Linear(n_inputs, n_outputs),
            torch.nn.Softmax(dim=-1))

    def forward(self, x):
        for layer in self.hidden_layers:
            x = layer(x)
        out = self.output_layer(x)
        return out
######################################################################
class Gate(torch.nn.Module):

    def __init__(self, n_inputs, hidden_units_by_layer, n_outputs):
        super().__init__()
        self.hidden_units_by_layer = hidden_units_by_layer
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.hidden_layers = torch.nn.ModuleList()
        for nh in hidden_units_by_layer:
            self.hidden_layers.append(torch.nn.Sequential(
                torch.nn.Linear(n_inputs, nh),
                torch.nn.Tanh()))
            n_inputs = nh
        self.output_layer = torch.nn.Sequential(
            torch.nn.Linear(n_inputs, n_outputs),
            torch.nn.Softmax(dim=-1))

    def forward(self, x):
        for layer in self.hidden_layers:
            x = layer(x)
        out = self.output_layer(x)
        return out
######################################################################
class MoE(torch.nn.Module):

    def __init__(self, n_inputs, hidden_units_by_layer_experts, hidden_units_by_layer_gate, n_experts, n_outputs):
        super().__init__()
        self.n_inputs = n_inputs
        self.hidden_units_by_layer_experts = hidden_units_by_layer_experts
        self.hidden_units_by_layer_gate = hidden_units_by_layer_gate
        self.n_experts = n_experts
        self.n_outputs = n_outputs
        self.Xmeans = None
        self.experts = torch.nn.ModuleList()
        for expert in range(n_experts):
            self.experts.append(Expert(n_inputs, hidden_units_by_layer_experts, n_outputs))
        self.gate = Gate(n_inputs, hidden_units_by_layer_gate, n_experts)
        self.output_layer = torch.nn.Linear(n_outputs, n_outputs)

    def forward(self, x):
        # Each expert produces class probabilities; stack them along a new last dimension.
        self.Y_experts = torch.stack([expert(x) for expert in self.experts], dim=-1)
        # The gate produces one weight per expert for each sample.
        self.Y_gate = self.gate(x)
        # Weight each expert's output by its gate value, then average over experts.
        self.Y_experts_weighted = self.Y_experts * self.Y_gate[:, None, :]
        self.Y_experts_weighted = torch.mean(self.Y_experts_weighted, dim=-1)
        out = self.output_layer(self.Y_experts_weighted)
        return out

    def __repr__(self):
        return f'''MoE(n_inputs={self.n_inputs},
    hidden_units_by_layer={self.hidden_units_by_layer_experts},
    hidden_units_by_layer_gate={self.hidden_units_by_layer_gate},
    n_experts={self.n_experts},
    n_outputs={self.n_outputs})'''
    def train(self, Xtrain, Ttrain, Xval, Tval, n_epochs, learning_rate, batch_size=-1):
        # Standardize inputs using the training set's means and standard deviations.
        if self.Xmeans is None:
            self.Xmeans = torch.mean(Xtrain, axis=0)
            self.Xstds = torch.std(Xtrain, axis=0)
            self.Xstds[self.Xstds == 0] = 1
        Xtrain = (Xtrain - self.Xmeans) / self.Xstds
        Xval = (Xval - self.Xmeans) / self.Xstds

        if batch_size == -1:
            batch_size = Xtrain.shape[0]

        loss_func = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        # optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)

        self.error_trace = []
        self.percent_correct = []
        rows = np.arange(Xtrain.shape[0])

        for epoch in range(n_epochs):
            np.random.shuffle(rows)
            for first in range(0, Xtrain.shape[0], batch_size):
                Xtrain_batch = Xtrain[rows[first:first + batch_size], :]
                Ttrain_batch = Ttrain[rows[first:first + batch_size]]
                # Forward pass
                outputs = self.forward(Xtrain_batch)
                loss = loss_func(outputs, Ttrain_batch)
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Record likelihood and percent correct for the last batch of this epoch.
            self.error_trace.append(np.exp(-loss.detach().cpu()))
            Yc = outputs.cpu().detach().numpy().argmax(1).reshape(-1)
            Tc = Ttrain_batch.cpu().numpy()
            percent = np.mean(Yc == Tc) * 100
            self.percent_correct.append(percent)
            if (epoch + 1) % max(1, n_epochs // 10) == 0:
                print(f'Epoch {epoch + 1}, {percent:.1f}')

    def use(self, X):
        X = (X - self.Xmeans) / self.Xstds
        outputs = self.forward(X)
        probs = torch.nn.functional.softmax(outputs, dim=-1).cpu().detach()
        Yc = outputs.cpu().detach().numpy().argmax(1).reshape(-1, 1)
        return Yc, probs
Overwriting mixture_of_experts.py
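Before building the interactive app, a quick smoke test of the module can help confirm it behaves as expected. The sketch below uses purely synthetic data with assumed sizes (5 inputs, 3 classes, 4 experts); it only checks that MoE trains, predicts with use, and exposes the per-sample gate weights in nnet.Y_gate, which drawdigit.py reads later.

import torch
import mixture_of_experts as moe   # assumes mixture_of_experts.py is in the working directory

# Synthetic data: 200 samples, 5 inputs, 3 classes (illustrative values only).
X = torch.rand(200, 5)
T = torch.randint(0, 3, (200,))

nnet = moe.MoE(n_inputs=5, hidden_units_by_layer_experts=[10],
               hidden_units_by_layer_gate=[10], n_experts=4, n_outputs=3)
nnet.train(X, T, X, T, n_epochs=20, learning_rate=0.01, batch_size=50)

classes, probs = nnet.use(X[:5])   # predicted class labels and class probabilities
print(classes.reshape(-1))
print(nnet.Y_gate)                 # gate weights over the 4 experts, one row per sample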
%%writefile drawdigit.py
import numpy as np
import matplotlib.pyplot as plt
import torch
import gzip
import pickle
import pandas
from matplotlib.backend_bases import MouseButton
import mixture_of_experts as moe
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.mps.is_available() else 'cpu'
device = torch.device(device)
device = 'cpu'  # override the choice above and use the CPU
print('Using device', device)
####### Get MNIST data
if True:   # Set to False to skip loading and training; the bars below then show random values.

    with gzip.open('mnist.pkl.gz', 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

    Xtrain = train_set[0]
    Ttrain = train_set[1]
    Xval = valid_set[0]
    Tval = valid_set[1]
    Xtest = test_set[0]
    Ttest = test_set[1]

    n_inputs = Xtrain.shape[1]
    n_classes = len(np.unique(Ttrain))

    Xtrain = torch.from_numpy(Xtrain.astype(np.float32)).to(device)
    Ttrain = torch.from_numpy(Ttrain.astype(np.int64)).to(device)
    Xval = torch.from_numpy(Xval.astype(np.float32)).to(device)
    Tval = torch.from_numpy(Tval.astype(np.int64)).to(device)
    Xtest = torch.from_numpy(Xtest.astype(np.float32)).to(device)
    Ttest = torch.from_numpy(Ttest.astype(np.int64)).to(device)

    ##### Make and train an MOE model
    n_experts = 6
    nnet = moe.MoE(n_inputs, [2], [20], n_experts, n_classes).to(device)
    nnet.train(Xtrain, Ttrain, Xval, Tval, 100, 0.02, batch_size=1000)

else:
    nnet = None
    n_experts = 6
#### Now draw digit and classify it
image = np.zeros((28, 28)).astype(np.uint8)
figure = plt.figure(figsize=(12, 8))
plt.subplot(1, 3, 1)
plt.title('Mouse left down to draw\nMouse right to clear')
drawing_axis = plt.gca()
drawimage = drawing_axis.imshow(image, vmin=0, vmax=255, cmap='gray_r')
plt.subplot(1, 3, 2)
probdata = plt.bar(range(10), np.zeros(10))
plt.ylim(0, 1)
# prob_axis = plt.gca()
plt.subplot(1, 3, 3)
gatedata = plt.bar(range(n_experts), np.zeros(n_experts))
plt.ylim(0, 1)
#gate_axis = plt.gca()
def update_img_with_matplotlib():
    drawimage.set_data(image)
    plt.draw()
    # plt.pause(1e-3)
    # drawing_axis.imshow(-image, vmin=-1, vmax=0, cmap='gray')
    if nnet:
        predicted_class, probs = nnet.use(torch.from_numpy((image / 255).reshape(1, -1).astype(np.float32)).to(device))
        probs = probs.reshape(-1)
        gates = nnet.Y_gate.reshape(-1).detach().cpu()
    else:
        probs = np.random.uniform(0, 1, 10)
        gates = np.random.uniform(0, 1, n_experts)
    # prob_axis.cla()
    for i in range(10):
        probdata[i].set_height(probs[i])
    # prob_axis.bar(range(10), probs)
    # prob_axis.set_title('Digit Probabilities')
    # prob_axis.set_ylim(0, 1)
    # prob_axis.set_xlabel('Digit')
    # gate_axis.cla()
    for i in range(n_experts):
        gatedata[i].set_height(gates[i])
    # gate_axis.bar(range(n_experts), gates)
    # gate_axis.set_title('Gate Outputs')
    figure.canvas.draw()
def clip(v):
    return 0 if v < 0 else 27 if v > 27 else v
def on_click(event):
    if event.inaxes != drawing_axis:
        return
    if event.button is MouseButton.RIGHT:
        image[:, :] = 0
        update_img_with_matplotlib()
        plt.pause(1e-2)
def mouse_motion_event(event):
    if event.button is not None:
        if event.inaxes != drawing_axis:
            return
        y, x = int(round(event.xdata)), int(round(event.ydata))
        y1, y2 = clip(y - 1), clip(y + 1)
        x1, x2 = clip(x - 1), clip(x + 1)
        # Paint a small square, clipped to the 28 x 28 image boundaries.
        image[x1:x2 + 1, y1:y2 + 1] = 255
        update_img_with_matplotlib()
update_img_with_matplotlib()
figure.canvas.mpl_connect('motion_notify_event', mouse_motion_event)
figure.canvas.mpl_connect('button_press_event', on_click)
plt.show()
Overwriting drawdigit.py
Now run
python drawdigit.py
from the command line.
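The drawing script loads Xtest and Ttest but never uses them. If you want a quick accuracy check before drawing, a few lines like the following sketch could be added inside the if True: block of drawdigit.py, right after the call to nnet.train(...) (indented to match that block):

# Evaluate on the held-out MNIST test set (Xtest and Ttest are already loaded above).
classes, probs = nnet.use(Xtest)
accuracy = 100 * np.mean(classes.reshape(-1) == Ttest.cpu().numpy())
print(f'Test set accuracy: {accuracy:.1f}%')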